{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "843542f7-220a-4408-9f8a-848696092434",
   "metadata": {
    "id": "843542f7-220a-4408-9f8a-848696092434"
   },
   "source": [
    "# Build a Model to generate Synthetic Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8816fc8-9517-46ff-af27-9fd0060840aa",
   "metadata": {},
   "source": [
    "Code was written in Google Colab. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08a8d539-950b-4b58-abf4-f17bd832c0af",
   "metadata": {
    "id": "08a8d539-950b-4b58-abf4-f17bd832c0af"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "Ienu-NHTuUlT",
   "metadata": {
    "id": "Ienu-NHTuUlT"
   },
   "outputs": [],
   "source": [
    "!pip install -q gradio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5e737cd-27b0-4a2e-9a0c-dbb30ce5cdbf",
   "metadata": {
    "id": "c5e737cd-27b0-4a2e-9a0c-dbb30ce5cdbf"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "import json\n",
    "from google.colab import userdata\n",
    "\n",
    "from huggingface_hub import login\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig\n",
    "import torch\n",
    "\n",
    "import gradio as gr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "khD9X5-V_txO",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "khD9X5-V_txO",
    "outputId": "e2b8d8d0-0433-4b5f-c777-a675213a3f4c"
   },
   "outputs": [],
   "source": [
    "!pip install -U bitsandbytes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e47ead5f-b4e9-4e9f-acf9-be1ffb7fa6d7",
   "metadata": {
    "id": "e47ead5f-b4e9-4e9f-acf9-be1ffb7fa6d7"
   },
   "outputs": [],
   "source": [
    "hf_token = userdata.get('HF_TOKEN')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba104a9c-f298-4e90-9ceb-9d907e392d0d",
   "metadata": {
    "id": "ba104a9c-f298-4e90-9ceb-9d907e392d0d"
   },
   "source": [
    "## Open Source Models from HF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11b1eb65-8ef5-4e6d-9176-cf1f70d07fb6",
   "metadata": {
    "id": "11b1eb65-8ef5-4e6d-9176-cf1f70d07fb6"
   },
   "outputs": [],
   "source": [
    "deepseek_model = 'deepseek-ai/deepseek-llm-7b-chat'\n",
    "llama_model = 'meta-llama/Meta-Llama-3.1-8B-Instruct'\n",
    "qwen2 = 'Qwen/Qwen2-7B-Instruct'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90fb1d2e-5d25-4d73-b629-8273ab71503c",
   "metadata": {
    "id": "90fb1d2e-5d25-4d73-b629-8273ab71503c"
   },
   "outputs": [],
   "source": [
    "login(hf_token, add_to_git_credential=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52948c01-8dc6-404b-a2c1-c87f9f6dbd64",
   "metadata": {
    "id": "52948c01-8dc6-404b-a2c1-c87f9f6dbd64"
   },
   "source": [
    "## Creating Prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79374337-34fe-4002-b173-ac9b132a54d8",
   "metadata": {
    "id": "79374337-34fe-4002-b173-ac9b132a54d8"
   },
   "outputs": [],
   "source": [
    "system_prompt = \"You are an expert in generating synthetic datasets. Your goal is to generate realistic datasets \\\n",
    "based on a given business and its requirements from the user. You will also be given the desired datset format.\"\n",
    "system_prompt += \"Do not repeat the instructions.\"\n",
    "\n",
    "user_prompt = (\"Please provide me a dataset for the following business.\"\n",
    "\"For example:\\n\"\n",
    "\"The Business: A retail store selling luxury watches.\\n\"\n",
    "\"The Data Format: CSV.\\n\"\n",
    "\"Output:\\n\"\n",
    "\"Item,Price,Quantity,Brand,Sale Date\\n\"\n",
    "\"Superocean II, 20.000$, 3, Breitling, 2025-04-08 \\n\"\n",
    "\"If I don't provide you the necessary columns, please create the columns based on your knowledge about the given business\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcd90b5e-a7d2-4cdc-81ff-17974c5ff1fe",
   "metadata": {
    "id": "dcd90b5e-a7d2-4cdc-81ff-17974c5ff1fe"
   },
   "outputs": [],
   "source": [
    "def dataset_format(data_format, num_records):\n",
    "    format_message = ''\n",
    "    if data_format == 'CSV':\n",
    "        format_message = 'Please provide the dataset in a CSV format.'\n",
    "    elif data_format == 'JSON':\n",
    "        format_message =  'Please provide the dataset in a JSON format'\n",
    "    elif data_format == 'Tabular':\n",
    "        format_message =  'Please provide the dataset in a Tabular format'\n",
    "\n",
    "    return format_message + f'Please generate {num_records} records'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39243edb-3eba-46fd-a610-e474ed421b01",
   "metadata": {
    "id": "39243edb-3eba-46fd-a610-e474ed421b01"
   },
   "outputs": [],
   "source": [
    "def complete_user_prompt(user_input, data_format, num_records):\n",
    "    messages = [\n",
    "        {'role': 'system', 'content': system_prompt},\n",
    "        {'role': 'user', 'content': user_input + user_prompt + dataset_format(data_format, num_records)}\n",
    "    ]\n",
    "\n",
    "    return messages"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ac81127-b9cc-424b-8b38-8a8b09bcc226",
   "metadata": {
    "id": "1ac81127-b9cc-424b-8b38-8a8b09bcc226"
   },
   "source": [
    "## Accessing the Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc4aaab5-bde1-463b-b873-e8bd1a231dc1",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "cc4aaab5-bde1-463b-b873-e8bd1a231dc1",
    "outputId": "16c9420d-2c4a-4e57-f281-7c531b5145db"
   },
   "outputs": [],
   "source": [
    "print(\"CUDA available:\", torch.cuda.is_available())\n",
    "if torch.cuda.is_available():\n",
    "    print(\"GPU-Device:\", torch.cuda.get_device_name(torch.cuda.current_device()))\n",
    "else:\n",
    "    print(\"No GPU found.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b8e648d-747f-4684-a20b-b8da550efc23",
   "metadata": {
    "id": "6b8e648d-747f-4684-a20b-b8da550efc23"
   },
   "outputs": [],
   "source": [
    "quant_config = BitsAndBytesConfig(\n",
    "    load_in_4bit = True,\n",
    "    bnb_4bit_use_double_quant = False,\n",
    "    bnb_4bit_compute_dtype= torch.bfloat16,\n",
    "    bnb_4bit_quant_type= 'nf4'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3ae602f-0abf-420d-8c7b-1938cba92528",
   "metadata": {
    "id": "b3ae602f-0abf-420d-8c7b-1938cba92528"
   },
   "outputs": [],
   "source": [
    "def generate_model(model_id, messages):\n",
    "    try:\n",
    "      tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code = True)\n",
    "      inputs = tokenizer.apply_chat_template(messages, return_tensors = 'pt').to('cuda')\n",
    "      streamer = TextStreamer(tokenizer)\n",
    "      model = AutoModelForCausalLM.from_pretrained(model_id, device_map = 'auto', quantization_config = quant_config)\n",
    "      outputs = model.generate(inputs, max_new_tokens = 2000, streamer = streamer)\n",
    "      generated_text = tokenizer.decode(outputs[0], skip_special_tokens = True)\n",
    "      del tokenizer, streamer, model, inputs, outputs\n",
    "      return generated_text\n",
    "\n",
    "    except Exception as e:\n",
    "      return f'Error during generation: {str(e)}'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c575c9e-4674-4eee-a9b9-c8d14ceed474",
   "metadata": {
    "id": "7c575c9e-4674-4eee-a9b9-c8d14ceed474"
   },
   "source": [
    "## Generate Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9c5963e-9f4e-4990-b744-b9ead03e623a",
   "metadata": {
    "id": "d9c5963e-9f4e-4990-b744-b9ead03e623a"
   },
   "outputs": [],
   "source": [
    "def generate_dataset(user_input, target_format, model_choice, num_records):\n",
    "    if model_choice == 'DeepSeek':\n",
    "        model_id = deepseek_model\n",
    "    elif model_choice == 'Llama-3.1-8B':\n",
    "        model_id = llama_model\n",
    "    elif model_choice == 'Qwen2':\n",
    "        model_id = qwen2\n",
    "\n",
    "    messages = complete_user_prompt(user_input, target_format, num_records)\n",
    "    return generate_model(model_id, messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff574cfe-567f-4c6d-b944-fb756bf7ebca",
   "metadata": {
    "id": "ff574cfe-567f-4c6d-b944-fb756bf7ebca"
   },
   "source": [
    "## Creating Gradio UI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61d2b056-0d00-4b73-b083-024a8f374fef",
   "metadata": {
    "id": "61d2b056-0d00-4b73-b083-024a8f374fef"
   },
   "outputs": [],
   "source": [
    "with gr.Blocks(title = 'Synthetic Data Generator') as ui:\n",
    "    gr.Markdown('# Synthetic Data Generator')\n",
    "\n",
    "    with gr.Row():\n",
    "        with gr.Column(min_width=600):\n",
    "            user_inputs = gr.Textbox(label = 'Enter your Business details and data requirements',\n",
    "                                     placeholder = 'Type here...', lines = 15)\n",
    "\n",
    "            model_choice = gr.Dropdown(\n",
    "                ['DeepSeek', 'Llama-3.1-8B', 'Qwen2'],\n",
    "                label = 'Choose your Model',\n",
    "                value = 'DeepSeek'\n",
    "            )\n",
    "\n",
    "            target_format = gr.Dropdown(\n",
    "                ['CSV', 'JSON', 'Tabular'],\n",
    "                label = 'Choose your Format',\n",
    "                value = 'CSV'\n",
    "            )\n",
    "            num_records = gr.Dropdown(\n",
    "                [50, 100, 150, 200],\n",
    "                label = 'Number of Records',\n",
    "                value = 50\n",
    "            )\n",
    "\n",
    "            generate_button = gr.Button('Generate')\n",
    "\n",
    "        with gr.Column():\n",
    "            output = gr.Textbox(label = 'Generated Synthetic Data',\n",
    "                               lines = 30)\n",
    "\n",
    "    generate_button.click(fn = generate_dataset, inputs = [user_inputs, target_format, model_choice, num_records],\n",
    "                          outputs = output\n",
    "                         )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "958d9cbf-50ff-4c50-a305-18df6d5f5eda",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 626
    },
    "id": "958d9cbf-50ff-4c50-a305-18df6d5f5eda",
    "outputId": "a6736641-85c3-4b6a-a28d-02ac5caf4562",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "ui.launch(inbrowser = True)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
